
use crate::{array::{arr::WArr, data::Data, error::{WError, WResult}}, dtype::cpu::F32};




/// Adds an array built by `from_vec1` to one built by `from_vec2` and checks
/// that the operation fails with `WError::ShapeNumMisMatch(1, 2)`.
///
/// Both operands name their element type through the turbofish form.
#[test]
fn test_from_vec1() -> WResult<()> {
    let vector = WArr::<F32>::from_vec1(vec![1.0f32; 10])?;
    let matrix = WArr::<F32>::from_vec2(vec![vec![2.0f32; 20]; 10])?;
    // `add` is expected to fail, so only the error side of the result is inspected.
    assert_eq!(
        vector.add(&matrix).err(),
        Some(WError::ShapeNumMisMatch(1, 2))
    );
    Ok(())
}

/// Same scenario as the turbofish variant, but the element type is supplied by
/// a `WArr<F32>` annotation on the first operand and inferred for the second.
///
/// Expects `add` to fail with `WError::ShapeNumMisMatch(1, 2)`.
#[test]
fn test_from_vec2() -> WResult<()> {
    let vector: WArr<F32> = WArr::from_vec1(vec![1.0f32; 10])?;
    let matrix = WArr::from_vec2(vec![vec![2.0f32; 20]; 10])?;
    assert_eq!(
        vector.add(&matrix).err(),
        Some(WError::ShapeNumMisMatch(1, 2))
    );
    Ok(())
}


/// Builds two constant-filled arrays with `from_shape` — one from the shape `10`,
/// one from the shape `(10, 20)` — and checks that adding them fails with
/// `WError::ShapeNumMisMatch(1, 2)`.
#[test]
fn test_from_shape() -> WResult<()> {
    let vector: WArr<F32> = WArr::from_shape(10, 1.0f32)?;
    let matrix = WArr::from_shape((10, 20), 1.0f32)?;
    assert_eq!(
        vector.add(&matrix).err(),
        Some(WError::ShapeNumMisMatch(1, 2))
    );
    Ok(())
}


/// Builds two arrays with `ones` — shapes `10` and `(10, 20)` — and checks that
/// adding them fails with `WError::ShapeNumMisMatch(1, 2)`.
#[test]
fn test_vec_shape3() -> WResult<()> {
    let vector: WArr<F32> = WArr::ones(10)?;
    let matrix = WArr::ones((10, 20))?;
    assert_eq!(
        vector.add(&matrix).err(),
        Some(WError::ShapeNumMisMatch(1, 2))
    );
    Ok(())
}


/// Multiplies a 4-row by 3-column array with a 3-row by 5-column array and
/// compares the product with hard-coded reference values.
#[test]
fn test_vec2_matmul() -> WResult<()> {
    let lhs: WArr<F32> = WArr::from_vec2(vec![
        vec![1.0f32, 2.0, 3.0],
        vec![4.0, 5.0, 6.0],
        vec![7.0, 8.0, 9.0],
        vec![10.0, 11.0, 12.0],
    ])?;
    let rhs = WArr::from_vec2(vec![
        vec![10.0f32, 20.0, 30.0, 40.0, 50.0],
        vec![60.0, 70.0, 80.0, 90.0, 100.0],
        vec![110.0, 120.0, 130.0, 140.0, 150.0],
    ])?;
    // Reference product, one inner `vec!` per output row.
    let expected = WArr::from_vec2(vec![
        vec![460., 520., 580., 640., 700.],
        vec![1000., 1150., 1300., 1450., 1600.],
        vec![1540., 1780., 2020., 2260., 2500.],
        vec![2080., 2410., 2740., 3070., 3400.],
    ])?;
    let product = lhs.matmul(&rhs)?;
    assert_eq!(product, expected);
    Ok(())
}


/// Runs `matmul` on two 4-D arrays: `0..120` reshaped to `(2, 4, 3, 5)` and
/// `200..320` reshaped to `(2, 4, 5, 3)`, then compares the result with a
/// hard-coded nested-`Vec` reference.
#[test]
fn test_vec3_matmul() -> WResult<()> {
    let lhs_values = (0..120).collect::<Vec<_>>();
    let rhs_values = (200..320).collect::<Vec<_>>();
    let lhs: WArr<F32> = WArr::from_vec1_f64(lhs_values)?.reshape((2, 4, 3, 5))?;
    let rhs = WArr::from_vec1_f64(rhs_values)?.reshape((2, 4, 5, 3))?;

    // Reference values, nested in the same order as the original literal.
    let expected = vec![
        vec![
            vec![
                vec![2090, 2100, 2110],
                vec![7240, 7275, 7310],
                vec![12390, 12450, 12510],
            ],
            vec![
                vec![18815, 18900, 18985],
                vec![24340, 24450, 24560],
                vec![29865, 30000, 30135],
            ],
            vec![
                vec![37790, 37950, 38110],
                vec![43690, 43875, 44060],
                vec![49590, 49800, 50010],
            ],
            vec![
                vec![59015, 59250, 59485],
                vec![65290, 65550, 65810],
                vec![71565, 71850, 72135],
            ],
        ],
        vec![
            vec![
                vec![82490, 82800, 83110],
                vec![89140, 89475, 89810],
                vec![95790, 96150, 96510],
            ],
            vec![
                vec![108215, 108600, 108985],
                vec![115240, 115650, 116060],
                vec![122265, 122700, 123135],
            ],
            vec![
                vec![136190, 136650, 137110],
                vec![143590, 144075, 144560],
                vec![150990, 151500, 152010],
            ],
            vec![
                vec![166415, 166950, 167485],
                vec![174190, 174750, 175310],
                vec![181965, 182550, 183135],
            ],
        ],
    ];
    let expected = WArr::from_vec4_f64(expected)?;
    let product = lhs.matmul(&rhs)?;
    assert_eq!(product, expected);
    Ok(())
}




/// Reshapes `0..210` to `(2, 3, 5, 7)`, calls `transpose(1, 2)` and compares the
/// result with a hard-coded nested-`Vec` reference.
#[test]
fn test_vec4_transpose1() -> WResult<()> {
    let values = (0..210).collect::<Vec<_>>();
    let source: WArr<F32> = WArr::from_vec1_f64(values)?.reshape((2, 3, 5, 7))?;
    let transposed = source.transpose(1, 2)?;
    let expected = vec![vec![vec![vec![0,1,2,3,4,5,6],vec![35,36,37,38,39,40,41],vec![70,71,72,73,74,75,76]],vec![vec![7,8,9,10,11,12,13],vec![42,43,44,45,46,47,48],vec![77,78,79,80,81,82,83]],vec![vec![14,15,16,17,18,19,20],vec![49,50,51,52,53,54,55],vec![84,85,86,87,88,89,90]],vec![vec![21,22,23,24,25,26,27],vec![56,57,58,59,60,61,62],vec![91,92,93,94,95,96,97]],vec![vec![28,29,30,31,32,33,34],vec![63,64,65,66,67,68,69],vec![98,99,100,101,102,103,104]]],vec![vec![vec![105,106,107,108,109,110,111],vec![140,141,142,143,144,145,146],vec![175,176,177,178,179,180,181]],vec![vec![112,113,114,115,116,117,118],vec![147,148,149,150,151,152,153],vec![182,183,184,185,186,187,188]],vec![vec![119,120,121,122,123,124,125],vec![154,155,156,157,158,159,160],vec![189,190,191,192,193,194,195]],vec![vec![126,127,128,129,130,131,132],vec![161,162,163,164,165,166,167],vec![196,197,198,199,200,201,202]],vec![vec![133,134,135,136,137,138,139],vec![168,169,170,171,172,173,174],vec![203,204,205,206,207,208,209]]]];
    let expected = WArr::from_vec4_f64(expected)?;
    assert_eq!(transposed, expected);
    Ok(())
}



/// Reshapes `0..210` to `(2, 3, 5, 7)`, calls `transpose(0, 1)` and compares the
/// result with a hard-coded nested-`Vec` reference.
#[test]
fn test_vec4_transpose2() -> WResult<()> {
    let values = (0..210).collect::<Vec<_>>();
    let source: WArr<F32> = WArr::from_vec1_f64(values)?.reshape((2, 3, 5, 7))?;
    let transposed = source.transpose(0, 1)?;
    let expected = vec![vec![vec![vec![0,1,2,3,4,5,6],vec![7,8,9,10,11,12,13],vec![14,15,16,17,18,19,20],vec![21,22,23,24,25,26,27],vec![28,29,30,31,32,33,34]],vec![vec![105,106,107,108,109,110,111],vec![112,113,114,115,116,117,118],vec![119,120,121,122,123,124,125],vec![126,127,128,129,130,131,132],vec![133,134,135,136,137,138,139]]],vec![vec![vec![35,36,37,38,39,40,41],vec![42,43,44,45,46,47,48],vec![49,50,51,52,53,54,55],vec![56,57,58,59,60,61,62],vec![63,64,65,66,67,68,69]],vec![vec![140,141,142,143,144,145,146],vec![147,148,149,150,151,152,153],vec![154,155,156,157,158,159,160],vec![161,162,163,164,165,166,167],vec![168,169,170,171,172,173,174]]],vec![vec![vec![70,71,72,73,74,75,76],vec![77,78,79,80,81,82,83],vec![84,85,86,87,88,89,90],vec![91,92,93,94,95,96,97],vec![98,99,100,101,102,103,104]],vec![vec![175,176,177,178,179,180,181],vec![182,183,184,185,186,187,188],vec![189,190,191,192,193,194,195],vec![196,197,198,199,200,201,202],vec![203,204,205,206,207,208,209]]]];
    let expected = WArr::from_vec4_f64(expected)?;
    assert_eq!(transposed, expected);
    Ok(())
}



/// Reshapes `0..240` to `(8, 2, 3, 5)`, calls `transpose(0, 1)` and compares the
/// result with a hard-coded nested-`Vec` reference.
#[test]
fn test_vec4_transpose3() -> WResult<()> {
    let values = (0..240).collect::<Vec<_>>();
    let source: WArr<F32> = WArr::from_vec1_f64(values)?.reshape((8, 2, 3, 5))?;
    let transposed = source.transpose(0, 1)?;
    let expected = vec![vec![vec![vec![0,1,2,3,4],vec![5,6,7,8,9],vec![10,11,12,13,14]],vec![vec![30,31,32,33,34],vec![35,36,37,38,39],vec![40,41,42,43,44]],vec![vec![60,61,62,63,64],vec![65,66,67,68,69],vec![70,71,72,73,74]],vec![vec![90,91,92,93,94],vec![95,96,97,98,99],vec![100,101,102,103,104]],vec![vec![120,121,122,123,124],vec![125,126,127,128,129],vec![130,131,132,133,134]],vec![vec![150,151,152,153,154],vec![155,156,157,158,159],vec![160,161,162,163,164]],vec![vec![180,181,182,183,184],vec![185,186,187,188,189],vec![190,191,192,193,194]],vec![vec![210,211,212,213,214],vec![215,216,217,218,219],vec![220,221,222,223,224]]],vec![vec![vec![15,16,17,18,19],vec![20,21,22,23,24],vec![25,26,27,28,29]],vec![vec![45,46,47,48,49],vec![50,51,52,53,54],vec![55,56,57,58,59]],vec![vec![75,76,77,78,79],vec![80,81,82,83,84],vec![85,86,87,88,89]],vec![vec![105,106,107,108,109],vec![110,111,112,113,114],vec![115,116,117,118,119]],vec![vec![135,136,137,138,139],vec![140,141,142,143,144],vec![145,146,147,148,149]],vec![vec![165,166,167,168,169],vec![170,171,172,173,174],vec![175,176,177,178,179]],vec![vec![195,196,197,198,199],vec![200,201,202,203,204],vec![205,206,207,208,209]],vec![vec![225,226,227,228,229],vec![230,231,232,233,234],vec![235,236,237,238,239]]]];
    let expected = WArr::from_vec4_f64(expected)?;
    assert_eq!(transposed, expected);
    Ok(())
}


/// Reshapes `1..7` to `(2, 3)`, calls the `t()` shorthand and compares the
/// result with a three-row, two-column reference.
#[test]
fn test_vec4_transpose4() -> WResult<()> {
    let values = (1..7).collect::<Vec<_>>();
    let source: WArr<F32> = WArr::from_vec1_f64(values)?.reshape((2, 3))?;
    let transposed = source.t()?;
    let expected = WArr::from_vec2_f64(vec![vec![1, 4], vec![2, 5], vec![3, 6]])?;
    assert_eq!(transposed, expected);
    Ok(())
}


/// Calls `concat(_, 2)` on a `(2, 4, 3, 5)` array holding `0..120` and a
/// `(2, 4, 4, 5)` array holding `200..360`, then compares the result with a
/// hard-coded nested-`Vec` reference.
#[test]
fn test_vec_concat() -> WResult<()> {
    let first_values = (0..120).collect::<Vec<_>>();
    let second_values = (200..360).collect::<Vec<_>>();
    let first: WArr<F32> = WArr::from_vec1_f64(first_values)?.reshape((2, 4, 3, 5))?;
    let second = WArr::from_vec1_f64(second_values)?.reshape((2, 4, 4, 5))?;
    let joined = first.concat(&second, 2)?;
    let expected = vec![vec![vec![vec![0,1,2,3,4],vec![5,6,7,8,9],vec![10,11,12,13,14],vec![200,201,202,203,204],vec![205,206,207,208,209],vec![210,211,212,213,214],vec![215,216,217,218,219]],vec![vec![15,16,17,18,19],vec![20,21,22,23,24],vec![25,26,27,28,29],vec![220,221,222,223,224],vec![225,226,227,228,229],vec![230,231,232,233,234],vec![235,236,237,238,239]],vec![vec![30,31,32,33,34],vec![35,36,37,38,39],vec![40,41,42,43,44],vec![240,241,242,243,244],vec![245,246,247,248,249],vec![250,251,252,253,254],vec![255,256,257,258,259]],vec![vec![45,46,47,48,49],vec![50,51,52,53,54],vec![55,56,57,58,59],vec![260,261,262,263,264],vec![265,266,267,268,269],vec![270,271,272,273,274],vec![275,276,277,278,279]]],vec![vec![vec![60,61,62,63,64],vec![65,66,67,68,69],vec![70,71,72,73,74],vec![280,281,282,283,284],vec![285,286,287,288,289],vec![290,291,292,293,294],vec![295,296,297,298,299]],vec![vec![75,76,77,78,79],vec![80,81,82,83,84],vec![85,86,87,88,89],vec![300,301,302,303,304],vec![305,306,307,308,309],vec![310,311,312,313,314],vec![315,316,317,318,319]],vec![vec![90,91,92,93,94],vec![95,96,97,98,99],vec![100,101,102,103,104],vec![320,321,322,323,324],vec![325,326,327,328,329],vec![330,331,332,333,334],vec![335,336,337,338,339]],vec![vec![105,106,107,108,109],vec![110,111,112,113,114],vec![115,116,117,118,119],vec![340,341,342,343,344],vec![345,346,347,348,349],vec![350,351,352,353,354],vec![355,356,357,358,359]]]];
    let expected = WArr::from_vec4_f64(expected)?;
    assert_eq!(joined, expected);
    Ok(())
}



/// Calls `stack(_, 2)` on two `(2, 4, 3, 5)` arrays holding `0..120` and
/// `200..320`, then compares the result with a five-level nested-`Vec`
/// reference built through `from_vec5_f64`.
#[test]
fn test_vec_stack() -> WResult<()> {
    let first_values = (0..120).collect::<Vec<_>>();
    let second_values = (200..320).collect::<Vec<_>>();
    let first: WArr<F32> = WArr::from_vec1_f64(first_values)?.reshape((2, 4, 3, 5))?;
    let second = WArr::from_vec1_f64(second_values)?.reshape((2, 4, 3, 5))?;
    let stacked = first.stack(&second, 2)?;
    let expected = vec![vec![vec![vec![vec![0,1,2,3,4],vec![5,6,7,8,9],vec![10,11,12,13,14]],vec![vec![200,201,202,203,204],vec![205,206,207,208,209],vec![210,211,212,213,214]]],vec![vec![vec![15,16,17,18,19],vec![20,21,22,23,24],vec![25,26,27,28,29]],vec![vec![215,216,217,218,219],vec![220,221,222,223,224],vec![225,226,227,228,229]]],vec![vec![vec![30,31,32,33,34],vec![35,36,37,38,39],vec![40,41,42,43,44]],vec![vec![230,231,232,233,234],vec![235,236,237,238,239],vec![240,241,242,243,244]]],vec![vec![vec![45,46,47,48,49],vec![50,51,52,53,54],vec![55,56,57,58,59]],vec![vec![245,246,247,248,249],vec![250,251,252,253,254],vec![255,256,257,258,259]]]],vec![vec![vec![vec![60,61,62,63,64],vec![65,66,67,68,69],vec![70,71,72,73,74]],vec![vec![260,261,262,263,264],vec![265,266,267,268,269],vec![270,271,272,273,274]]],vec![vec![vec![75,76,77,78,79],vec![80,81,82,83,84],vec![85,86,87,88,89]],vec![vec![275,276,277,278,279],vec![280,281,282,283,284],vec![285,286,287,288,289]]],vec![vec![vec![90,91,92,93,94],vec![95,96,97,98,99],vec![100,101,102,103,104]],vec![vec![290,291,292,293,294],vec![295,296,297,298,299],vec![300,301,302,303,304]]],vec![vec![vec![105,106,107,108,109],vec![110,111,112,113,114],vec![115,116,117,118,119]],vec![vec![305,306,307,308,309],vec![310,311,312,313,314],vec![315,316,317,318,319]]]]];
    let expected = WArr::from_vec5_f64(expected)?;
    assert_eq!(stacked, expected);
    Ok(())
}



/// Reshapes `0..120` to `(2, 4, 3, 5)`, calls `sum(2)` and compares the result
/// with a hard-coded four-level nested-`Vec` reference.
#[test]
fn test_arr_sum() -> WResult<()> {
    let values = (0..120).collect::<Vec<_>>();
    let source: WArr<F32> = WArr::from_vec1_f64(values)?.reshape((2, 4, 3, 5))?;
    let summed = source.sum(2)?;
    let expected = vec![vec![vec![vec![15,18,21,24,27]],vec![vec![60,63,66,69,72]],vec![vec![105,108,111,114,117]],vec![vec![150,153,156,159,162]]],vec![vec![vec![195,198,201,204,207]],vec![vec![240,243,246,249,252]],vec![vec![285,288,291,294,297]],vec![vec![330,333,336,339,342]]]];
    let expected = WArr::from_vec4_f64(expected)?;
    assert_eq!(summed, expected);
    Ok(())
}


/// Reshapes `0..40` to `(2, 4, 1, 5)`, calls `broadcast(2, 3)` and compares the
/// result with a reference in which each five-element row appears three times.
#[test]
fn test_arr_broadcast1() -> WResult<()> {
    let values = (0..40).collect::<Vec<_>>();
    let source: WArr<F32> = WArr::from_vec1_f64(values)?.reshape((2, 4, 1, 5))?;
    let expanded = source.broadcast(2, 3)?;
    let expected = vec![vec![vec![vec![0,1,2,3,4],vec![0,1,2,3,4],vec![0,1,2,3,4]],vec![vec![5,6,7,8,9],vec![5,6,7,8,9],vec![5,6,7,8,9]],vec![vec![10,11,12,13,14],vec![10,11,12,13,14],vec![10,11,12,13,14]],vec![vec![15,16,17,18,19],vec![15,16,17,18,19],vec![15,16,17,18,19]]],vec![vec![vec![20,21,22,23,24],vec![20,21,22,23,24],vec![20,21,22,23,24]],vec![vec![25,26,27,28,29],vec![25,26,27,28,29],vec![25,26,27,28,29]],vec![vec![30,31,32,33,34],vec![30,31,32,33,34],vec![30,31,32,33,34]],vec![vec![35,36,37,38,39],vec![35,36,37,38,39],vec![35,36,37,38,39]]]];
    let expected = WArr::from_vec4_f64(expected)?;
    assert_eq!(expanded, expected);
    Ok(())
}


/// Reshapes `0..40` to `(2, 4, 5)`, concatenates the array with itself via
/// `concat(_, 1)` and compares the result with a hard-coded reference.
#[test]
fn test_arr_concat1() -> WResult<()> {
    let values = (0..40).collect::<Vec<_>>();
    let source: WArr<F32> = WArr::from_vec1_f64(values)?.reshape((2, 4, 5))?;
    // The same array is used as both operands.
    let doubled = source.concat(&source, 1)?;
    let expected = vec![vec![vec![0,1,2,3,4],vec![5,6,7,8,9],vec![10,11,12,13,14],vec![15,16,17,18,19],vec![0,1,2,3,4],vec![5,6,7,8,9],vec![10,11,12,13,14],vec![15,16,17,18,19]],vec![vec![20,21,22,23,24],vec![25,26,27,28,29],vec![30,31,32,33,34],vec![35,36,37,38,39],vec![20,21,22,23,24],vec![25,26,27,28,29],vec![30,31,32,33,34],vec![35,36,37,38,39]]];
    let expected = WArr::from_vec3_f64(expected)?;
    assert_eq!(doubled, expected);
    Ok(())
}



/// Runs `conv2d` twice on a `(1, 4, 5, 5)` input with a `(2, 4, 3, 3)` kernel,
/// once with arguments `1, 0, 1, 1` and once with `2, 2, 2, 1`.
///
/// Both the computed and the reference arrays go through `round(2.0)` before
/// they are compared; each assertion message echoes the numeric arguments used.
#[test]
fn test_conv2d_normal1() -> WResult<()> {
    let input: WArr<F32> = WArr::from_vec1_f64(vec![-1.9, 2.1, 0.3, -0.5, 1.9, -1.6, -1.8, 0.3, -0.9, -0.4, -0.2, 1.4, -0.1, 0.8, 0.8, -0.9, -0.5, -1.0, -1.0, -0.4, 1.3, 0.5, 0.6, -0.6, 0.1, -0.4, 0.4, -0.1, 0.2, 1.5, -1.0, -0.8, 0.8, 0.9, 0.2, -1.5, 0.7, 0.3, 0.2, -0.4, 1.7, -0.5, 0.6, 0.8, -0.0, -0.3, 0.5, -0.8, -0.1, -0.2, 0.9, 0.6, -0.7, 0.0, -1.4, 0.1, -0.2, -0.6, 0.2, -0.3, -0.4, 1.6, 0.7, 2.7, -0.4, -0.2, 0.8, -1.6, -0.8, -0.6, -0.8, -1.4, 0.4, 0.4, 0.5, -0.5, -0.5, 1.0, -1.2, 2.5, 1.6, -0.1, 3.0, -0.4, 2.8, 0.2, 1.2, -1.3, 0.1, -0.1, 1.3, 1.3, 0.1, -0.6, -1.4, -2.6, 0.5, 0.7, 0.7, 0.4])?
        .reshape((1, 4, 5, 5))?;
    let kernel = WArr::from_vec1_f64(vec![2.0, -0.8, 1.0, -2.4, -1.2, -1.0, -0.4, -0.6, -0.5, 0.6, 0.2, 0.6, -0.3, -1.9, 0.9, -1.0, -1.9, -0.6, 1.7, -0.3, 0.8, 0.2, -0.4, 0.9, -0.3, 0.8, 0.2, 1.5, 0.0, 0.6, 0.5, 0.3, -0.3, 0.5, -0.5, 0.2, -1.2, -0.1, -0.1, 1.9, -2.1, -0.0, 0.6, 0.8, 1.1, 1.4, -0.8, 0.6, -0.7, -0.3, -0.5, -0.3, 0.4, -0.0, -0.3, 0.4, 0.3, -0.5, -0.0, -0.3, 0.7, 0.7, 0.9, 0.9, 0.8, -1.3, 0.9, 0.1, -0.0, -0.0, -0.8, -0.2])?
        .reshape((2, 4, 3, 3))?;

    // First configuration: arguments 1, 0, 1, 1.
    let actual = input.conv2d(&kernel, 1, 0, 1, 1)?.round(2.0)?;
    let expected = WArr::from_vec1_f64(vec![2.9700000286102295, 8.899999618530273, 3.4000000953674316, 1.7100000381469727, -7.920000076293945, 1.190000057220459, -2.8299999237060547, 16.559999465942383, 1.600000023841858, 3.940000057220459, 2.390000104904175, 4.570000171661377, -8.210000038146973, 2.190000057220459, -7.289999961853027, 1.7200000286102295, 0.18000000715255737, 0.8799999952316284])?
        .round(2.0)?
        .reshape((1, 2, 3, 3))?;
    assert_eq!(actual, expected, "conv2d 1 0 1 1");

    // Second configuration: arguments 2, 2, 2, 1.
    let actual = input.conv2d(&kernel, 2, 2, 2, 1)?.round(2.0)?;
    let expected = WArr::from_vec1_f64(vec![3.509999990463257, 4.699999809265137, -5.260000228881836, 7.0, -1.690000057220459, 0.9300000071525574, -2.7799999713897705, -5.159999847412109, -1.9700000286102295, 3.9000000953674316, -2.9000000953674316, -1.6799999475479126, 2.7200000286102295, 0.6399999856948853, -1.9199999570846558, 0.6899999976158142, -3.119999885559082, 1.1799999475479126])?
        .round(2.0)?
        .reshape((1, 2, 3, 3))?;
    assert_eq!(actual, expected, "conv2d 2 2 2 1");

    Ok(())
}


/// Runs `conv2d` twice on a `(3, 4, 7, 7)` input with a `(2, 4, 3, 3)` kernel,
/// once with arguments `1, 0, 1, 1` and once with `2, 3, 5, 1`.
///
/// Both the computed and the reference arrays go through `round(2.0)` before
/// they are compared.
#[test]
fn test_conv2d_normal2() -> WResult<()> {
    let input: WArr<F32> = WArr::from_vec1_f64(vec![0.20, 0.80, -0.20, 0.30, -1.10, 0.40, 0.20, 0.00, -0.20, -1.20, -0.80, 1.30, 0.20, 1.70, -0.70, 1.60, 3.20, -0.40, -1.20, 1.20, -0.70, -1.20, -1.30, -0.90, -1.50, -1.10, -0.70, -0.40, -0.60, 0.50, -0.30, 0.80, 1.50, 0.30, 1.20, 1.50, 1.10, -0.00, -1.00, -1.80, -1.20, 0.50, 0.10, -1.50, -1.80, 0.50, -1.80, -1.90, 1.70, -0.00, -0.50, -0.00, 1.00, -0.60, 1.50, -0.50, 0.20, 0.10, -1.20, 1.90, 2.20, 1.10, -1.10, 0.50, -0.40, 1.10, -0.60, -1.90, -1.60, 1.40, -0.90, -0.40, 0.40, -1.80, -0.50, 0.20, 1.40, 0.20, 0.60, -0.40, -1.70, -0.60, 0.40, 0.70, 0.40, -1.50, 0.30, 0.70, 1.50, 0.70, 3.30, 0.70, -1.40, 0.10, -1.10, -0.90, 0.10, 1.20, 0.90, 0.20, 0.90, 0.90, 0.80, -0.70, -1.50, -1.40, -0.50, 1.60, -1.20, -1.70, 0.40, -0.30, -0.30, -0.00, -0.30, 1.30, -0.50, 1.00, -0.50, -0.30, 0.60, 1.30, -0.40, -2.30, 0.40, -0.40, 0.40, 0.40, 0.60, 1.40, 1.40, 0.50, -1.50, 1.20, -1.10, -0.10, 0.50, -1.00, -0.30, -0.80, 0.80, 0.60, 1.00, -1.00, 0.30, 0.30, 0.80, -0.40, -1.00, 0.50, 1.10, 1.20, -0.10, -1.10, 0.80, -0.90, 1.20, -1.60, 0.40, 0.60, -0.10, -2.00, -1.00, -0.30, 0.10, 0.20, 2.00, -0.20, -0.70, 1.40, -0.40, 0.40, 0.30, 0.00, -0.40, 0.40, -0.80, 1.20, 1.20, 0.50, -0.30, -1.30, -0.10, 2.00, -0.30, -0.60, -0.60, 1.00, 0.50, -0.60, 0.80, -0.30, -0.60, 0.20, -1.70, 1.10, 0.10, -0.00, -0.40, -2.20, 1.30, 0.30, -0.20, -0.60, 0.00, -0.40, -0.20, 0.30, 0.60, -0.40, 0.50, -1.10, -1.60, 0.60, -2.50, -0.30, 0.50, 1.60, 1.80, -1.30, 0.60, -0.40, 1.30, 1.00, -0.20, -0.00, -0.90, 0.30, -0.60, 1.30, -0.20, -0.20, 0.10, 1.30, 0.10, -0.60, 1.20, -0.90, -1.30, 0.50, 1.50, 0.70, 1.40, 0.70, -0.60, 1.70, 2.10, 1.20, -0.60, 0.30, 1.30, 0.90, -0.40, -0.40, -0.80, 0.70, -0.50, -0.10, 0.20, 1.10, 2.50, -0.60, 0.50, 0.10, 0.40, -1.60, -0.70, -1.80, 0.90, 0.70, -0.50, 0.20, 0.40, 0.70, 0.10, 0.90, 0.70, -1.70, -0.80, -0.20, 1.10, -2.10, 0.90, 1.80, 1.00, -0.50, -0.80, -0.80, -0.90, -0.20, -0.10, -0.80, -0.60, -0.80, 1.30, 1.50, 0.70, -1.10, 0.30, -0.50, 0.30, 
-0.50, -0.20, -0.70, 1.50, 1.10, -1.00, 1.60, 1.30, 1.70, -0.30, -0.10, 0.50, 0.80, 2.20, -0.60, 0.00, -0.30, -0.20, -1.10, -0.50, -1.50, -2.70, -2.30, -0.30, 0.50, 0.30, -0.50, -0.80, -1.40, -0.60, -0.40, 1.10, 0.20, 0.40, -1.40, 0.00, -0.40, -0.60, -1.90, -1.00, -0.20, -1.20, 0.10, -1.40, -0.90, 0.80, -0.40, 0.60, 1.90, 1.00, 1.30, -0.80, 0.30, 0.30, -1.80, -0.40, -1.20, -1.10, -1.70, 0.60, 0.70, 0.10, 0.70, 0.90, -0.90, 0.80, -0.40, -2.70, 0.60, -0.60, -1.20, 1.40, -0.20, -0.20, 0.30, -1.10, -1.30, -1.00, 1.10, 0.30, 0.20, 2.00, 0.90, 0.20, -0.10, 0.20, 1.90, -0.90, 0.50, -0.40, -0.30, -0.40, 0.60, 0.40, -0.80, 0.30, -0.30, 0.60, -1.80, 0.60, -1.20, -0.20, -1.30, -0.50, 0.10, 1.00, -0.30, -1.60, 1.60, -0.90, 2.00, 0.80, -1.10, 0.30, -0.90, 1.40, -0.80, -1.20, -0.70, -0.60, -0.10, -1.30, -2.00, 0.30, 0.50, -1.30, -0.40, 0.60, -0.00, 0.10, 0.50, -0.80, -2.30, 0.60, -0.20, 1.20, 0.20, 1.60, -1.60, 1.30, -0.00, 1.60, 0.90, 0.60, 0.00, -1.70, 1.10, 0.30, 0.20, 0.50, 0.60, -0.30, -0.20, 1.70, -0.00, -0.50, 2.20, -1.30, 0.70, 0.90, 0.80, 1.40, -0.40, -0.60, -1.10, 0.40, -0.50, -1.00, -0.10, 0.40, 0.10, 1.60, 1.50, -0.30, 0.30, -0.20, -1.50, -0.70, 0.20, 0.50, -0.30, -1.20, 0.30, -0.70, 1.50, 1.00, 0.90, -0.90, 0.10, -1.20, -0.60, 0.90, 0.40, -1.70, -0.30, 0.70, -1.20, -0.50, 1.00, -1.00, 0.10, 1.30, 1.00, -0.80, 0.30, -0.40, -1.40, -1.40, 2.70, 1.50, -1.20, -0.20, -0.10, 0.60, 1.10, -0.90, 0.30, 0.40, 1.40, 1.20, 1.10, 2.80, -1.20, -0.30, -0.60, 0.80, -0.30, -0.20, -0.60, 0.80, -0.80, 1.20, 1.70, -0.60, 1.60, -0.50, 1.20, -0.80, 0.70, -1.20, 0.80, -0.50, 0.60, -0.10, -1.20, 0.00, 1.20, -2.30, 0.20, 0.10, 1.50, -1.30, -0.10, 0.10, 0.10, 0.60, -0.80, 0.00, 1.50, 0.60, -1.20, 1.00, 0.20, 0.30, -0.10, -0.00, -0.40, 0.10, 0.50, -1.00, -0.40, 1.10, 0.60, -0.50, 0.20, -1.50, -0.20, 0.60, -0.60, -0.50, 2.50, -0.80, 0.90, 1.00, 1.20, -0.30, -1.10])?
        .reshape((3, 4, 7, 7))?;
    let kernel = WArr::from_vec1_f64(vec![0.30, -0.90, 0.40, -1.80, -1.20, 0.20, 1.30, -2.00, 0.40, 0.80, -1.30, -1.30, -1.80, -0.20, 1.00, 0.50, -0.80, 1.90, 0.60, 0.00, -1.10, -0.90, -0.80, 0.20, 1.60, -0.10, 1.70, 0.50, 1.20, -2.50, 0.60, -1.40, 1.80, -1.10, -1.20, -0.20, -1.60, 0.70, 1.20, -2.20, 0.40, 0.10, 1.30, 0.10, 1.00, 0.10, 0.30, 0.30, -0.90, -0.90, -1.40, 0.70, 0.90, -0.70, 0.30, -0.90, -1.30, 0.60, -0.30, 0.60, 0.00, -0.70, -1.70, -0.40, -1.20, -0.30, 0.70, -0.30, -0.50, 1.40, 1.20, 0.70])?
        .reshape((2, 4, 3, 3))?;

    // First configuration: arguments 1, 0, 1, 1.
    let actual = input.conv2d(&kernel, 1, 0, 1, 1)?.round(2.0)?;
    let expected = WArr::from_vec1_f64(vec![3.41, -10.27, 7.77, 7.28, -9.76, -2.71, -5.94, -14.82, -2.69, 2.77, -0.73, -2.53, 4.55, 9.64, 8.65, 4.30, -3.55, 8.67, 0.56, -3.68, -6.04, 1.30, 1.18, 9.07, 17.04, -0.34, -7.30, 0.29, -6.25, -3.10, -10.05, -4.16, 3.71, 7.03, -0.54, 10.15, 5.15, -6.04, 3.97, 2.19, 1.74, -0.64, 4.30, 2.39, -2.60, -6.24, -5.71, -12.17, -10.09, -7.81, 0.38, 14.61, 3.04, -6.79, -7.40, -3.16, 3.58, -6.85, 14.74, -3.52, -0.27, -0.94, -15.19, 8.59, -7.50, -1.88, 3.37, 6.16, -6.51, -7.94, -4.61, -5.36, -4.91, -1.44, -6.76, 1.24, -0.88, -4.88, 7.24, -3.67, -4.90, 0.14, -0.05, -0.88, 9.04, -2.46, 1.04, 2.31, -6.07, -4.69, -5.88, -5.89, 10.86, 12.85, 16.25, -0.35, 5.10, 0.72, -0.19, 9.84, -6.67, 0.61, -4.85, 15.68, -0.38, 0.06, 2.49, -7.30, 2.46, 4.26, 2.44, 8.30, 17.85, -11.15, -5.29, -5.26, 4.03, 3.41, -1.04, -5.57, -3.25, 7.42, -3.20, -8.96, 3.69, 4.17, 1.00, 3.47, 0.97, -3.72, -6.78, 0.81, 0.03, -4.74, -4.29, -6.35, -1.66, 1.98, -5.41, 5.74, -0.72, -5.19, -1.07, -5.00, -1.44, 3.34, -4.92, 4.43, 6.56, 3.09])?
        .round(2.0)?
        .reshape((3, 2, 5, 5))?;
    assert_eq!(actual, expected, "conv2d 1 0 1 1");

    // Second configuration: arguments 2, 3, 5, 1.
    let actual = input.conv2d(&kernel, 2, 3, 5, 1)?.round(2.0)?;
    let expected = WArr::from_vec1_f64(vec![-3.40, 1.94, -1.72, -3.50, 0.47, 1.32, -0.30, 0.57, 4.66, 1.60, 1.42, 0.40, 0.50, -1.45, -1.02, 1.11, 1.66, 3.18, -0.12, -2.34, -2.53, -0.33, -1.12, -1.95])?
        .round(2.0)?
        .reshape((3, 2, 2, 2))?;
    assert_eq!(actual, expected, "conv2d 2, 3, 5, 1");

    Ok(())
}


/// Runs `conv2d` with arguments `2, 2, 2, 3` on a `(1, 6, 5, 5)` input and a
/// `(3, 2, 3, 3)` kernel, then compares the result with a `(1, 3, 3, 3)`
/// reference.
///
/// Both the computed and the reference arrays go through `round(2.0)` before
/// they are compared.
#[test]
fn test_conv2d_group() -> WResult<()> {
    let t: WArr<F32> = WArr::from_vec1_f64(vec![0.8, 1.1, -2.6, -1.6, 0.4, 0.2, -1.4, -0.3, 0.3, -1.6, 1.1, 0.5, 1.0, -0.7, 0.8, -0.9, -0.8, 1.6, 1.4, -1.1, -1.5, 0.1, -0.8, 1.0, 1.7, -0.5, -1.1, -0.4, -0.9, -0.5, -1.2, -0.1, 0.2, 1.2, -0.9, 1.1, 0.1, -1.1, -1.1, -0.9, 0.4, -2.3, -0.6, -0.3, 0.3, 0.3, -0.1, 0.8, -0.7, -0.4, 0.7, 1.3, -0.7, 0.3, 0.4, -1.7, -0.5, 0.2, -0.3, 0.6, -0.1, 1.4, -1.1, -1.4, -0.7, -0.3, 0.1, -1.0, -2.0, 1.7, 1.1, -0.1, 0.1, -0.5, -1.3, -1.3, -1.4, -0.8, 0.5, 1.5, -0.4, -0.3, -0.6, -0.6, 1.6, -0.7, 0.7, 2.1, 0.5, -0.5, -0.9, 0.7, -1.6, 0.2, 1.0, -0.7, -1.2, -0.8, -2.1, -1.0, -0.1, 0.9, -0.0, -0.3, 1.5, 0.2, -0.1, -1.4, -0.1, 1.4, -1.6, 1.6, -0.7, 1.8, 0.4, -0.5, 1.5, -2.0, -0.3, 0.5, -1.1, 0.3, 0.3, -1.2, 0.4, 0.3, 1.0, 0.5, 2.8, 0.3, 1.3, 1.2, 0.7, 1.8, 0.9, 0.1, -0.6, -0.6, 0.3, 1.0, 1.6, 1.0, -0.2, 1.1, -1.2, -0.0, 0.1, -0.0, -0.6, -0.2])?
        .reshape((1, 6, 5, 5))?;
    let w = WArr::from_vec1_f64(vec![0.3, 0.9, -1.8, -0.2, 1.3, -0.6, 0.7, 0.6, 1.8, 0.0, 0.3, -0.0, -0.1, -1.3, 0.2, -1.0, 0.2, 1.0, -0.6, -0.4, -1.3, 0.6, 1.2, 0.8, -2.6, 0.3, 1.1, -0.1, 1.5, -0.9, -1.2, 0.5, 0.3, -1.6, 1.0, -1.5, 0.0, 0.2, -0.3, 0.3, 2.8, 0.7, -1.7, 0.1, 2.0, 1.3, -0.3, -0.3, 0.8, -0.7, -0.2, -2.1, -0.9, 0.9])?
        .reshape((3, 2, 3, 3))?;

    let res = t.conv2d(&w, 2, 2, 2, 3)?.round(2.0)?;
    let res_real = WArr::from_vec1_f64(vec![4.75, -2.72, 3.83, 2.95, -0.21, 1.13, -2.18, -3.45, 3.56, -5.70, 4.64, 0.56, -0.38, -5.72, -2.05, -0.66, 4.96, -1.06, -2.78, 5.53, 5.98, -4.69, 0.36, 0.30, -2.83, 0.58, 0.35])?
        .round(2.0)?
        .reshape((1, 3, 3, 3))?;
    // The message echoes the numeric arguments passed to `conv2d` above, so a
    // failure report identifies the configuration that actually ran.
    assert_eq!(res, res_real, "conv2d 2 2 2 3");

    Ok(())
}



/// Calls `flipped(&[2, 4, 3, 3], &[2, 3])` on a raw `F32` buffer of 72 values
/// and compares the result with a hard-coded reference buffer.
///
/// Lengths are compared first, then the buffers themselves.
#[test]
fn test_flipped_1() -> WResult<()> {
    let weights = F32::new(vec![-0.9325, 0.6451, -0.8537, 0.2378, 0.8764, -0.1832, 0.2987, -0.6488, -0.2273, -2.4184, -0.1192, -0.4821, -0.5079, -0.5766, -2.4729, 1.6734, 0.4558, 0.2851, 1.1514, -0.9013, 1.0662, -0.1817, -0.0259, 0.1709, 0.5367, 0.7513, 0.8086, -2.2586, -0.5027, 0.9141, -1.3086, -1.3343, -1.5669, -0.1657, 0.7958, 0.1432, 0.3896, -0.4501, 0.1667, 0.0714, -0.0952, 1.2970, -0.1674, -0.3178, 1.0677, 0.3060, 0.7080, 0.1914, 1.1679, -0.3602, 1.9265, -1.8626, -0.5112, -0.0982, 0.2621, 0.6565, 0.5908, 1.0089, -0.1646, 1.8032, -0.6286, 0.2016, -0.3370, 1.2555, 0.8009, -0.6488, -0.4652, -1.5685, 1.5860, 0.5583, 0.4623, 0.6026])?;
    let shape = [2, 4, 3, 3];

    let flipped = weights.flipped(&shape, &[2, 3])?;

    let expected = F32::new(vec![-0.2273, -0.6488, 0.2987, -0.1832, 0.8764, 0.2378, -0.8537, 0.6451, -0.9325, 0.2851, 0.4558, 1.6734, -2.4729, -0.5766, -0.5079, -0.4821, -0.1192, -2.4184, 0.8086, 0.7513, 0.5367, 0.1709, -0.0259, -0.1817, 1.0662, -0.9013, 1.1514, 0.1432, 0.7958, -0.1657, -1.5669, -1.3343, -1.3086, 0.9141, -0.5027, -2.2586, 1.0677, -0.3178, -0.1674, 1.2970, -0.0952, 0.0714, 0.1667, -0.4501, 0.3896, -0.0982, -0.5112, -1.8626, 1.9265, -0.3602, 1.1679, 0.1914, 0.7080, 0.3060, -0.3370, 0.2016, -0.6286, 1.8032, -0.1646, 1.0089, 0.5908, 0.6565, 0.2621, 0.6026, 0.4623, 0.5583, 1.5860, -1.5685, -0.4652, -0.6488, 0.8009, 1.2555])?;

    assert_eq!(flipped.len(), expected.len(), "flipped 2 3 len");
    assert_eq!(flipped, expected, "flipped 2 3 data");

    Ok(())
}



/// Runs `conv_transpose2d` with arguments `1, 0, 0, 1, 1` on a `(1, 4, 5, 5)`
/// input. The kernel is built as `(2, 4, 3, 3)` and passed through
/// `transpose(0, 1)` before use. The result is compared with a `(1, 2, 7, 7)`
/// reference.
///
/// NOTE(review): only the computed side goes through `round(4.0)`; the reference
/// values are used as written, unlike the `conv2d` tests in this file, which
/// round both sides. Confirm this is intended.
#[test]
fn test_conv_transpose2d_1() -> WResult<()> {
    let input: WArr<F32> = WArr::from_vec1_f64(vec![
        0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616,
        1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699, 0.0823, 0.3526, 0.6843, 0.2395,
        1.2279, -0.9287, -1.7030, 0.1370, 0.6047, 0.3770, -0.6266, 0.3529, 2.2013, -0.6836,
        0.2477, 1.3127, -0.2260, 0.2622, -1.2974, -0.8140, -0.8404, -0.3490, 0.0130, 1.3123,
        1.7569, -0.3956, -1.8255, 0.1727, -0.3538, 2.6941, 1.0529, 0.4219, -0.2071, 1.1586,
        0.4717, 0.3865, -0.5690, -0.5010, -0.1310, 0.7796, 0.6630, -0.2021, 2.6090, 0.2049,
        0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, -1.3712,
        0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, 0.3790,
        -0.4431, -0.4720, -0.7890, 0.2620, 0.7875, 0.5377, -0.6779, -0.8088, 1.9098, 1.2006,
        -0.8, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
    ])?.reshape((1, 4, 5, 5))?;
    let kernel = WArr::from_vec1_f64(vec![
        -0.9325f32, 0.6451, -0.8537, 0.2378, 0.8764, -0.1832, 0.2987, -0.6488, -0.2273,
        -2.4184, -0.1192, -0.4821, -0.5079, -0.5766, -2.4729, 1.6734, 0.4558, 0.2851, 1.1514,
        -0.9013, 1.0662, -0.1817, -0.0259, 0.1709, 0.5367, 0.7513, 0.8086, -2.2586, -0.5027,
        0.9141, -1.3086, -1.3343, -1.5669, -0.1657, 0.7958, 0.1432, 0.3896, -0.4501, 0.1667,
        0.0714, -0.0952, 1.2970, -0.1674, -0.3178, 1.0677, 0.3060, 0.7080, 0.1914, 1.1679,
        -0.3602, 1.9265, -1.8626, -0.5112, -0.0982, 0.2621, 0.6565, 0.5908, 1.0089, -0.1646,
        1.8032, -0.6286, 0.2016, -0.3370, 1.2555, 0.8009, -0.6488, -0.4652, -1.5685, 1.5860,
        0.5583, 0.4623, 0.6026,
    ])?.reshape((2, 4, 3, 3))?;

    // The first two kernel axes are swapped before the call.
    let kernel = kernel.transpose(0, 1)?;
    let actual = input.conv_transpose2d(&kernel, 1, 0, 0, 1, 1)?.round(4.0)?;
    let expected = WArr::from_vec1_f64(vec![
        -1.9918, 2.6797, -0.4599, -1.6037, 1.4131, -2.4012, 2.9277,
        1.8016, -3.5361, 1.0757, 3.5395, -8.2168, -3.2023, 0.5375,
        0.8243, 1.8675, 7.8929, -4.0746, -6.4415, 5.1139, 1.6889,
        0.2722, 8.9679, 3.3477, 1.8514, -4.2896, -3.8228, -7.5632,
        -8.5412, -5.8142, -7.1587, -1.6095, 0.4651, 0.2748, -2.0985,
        2.0833, -0.6482, -12.1692, -4.1284, -2.9765, -0.0656, -4.5114,
        5.307, 2.6957, 2.3087, 1.0478, 0.7808, -1.1519, -0.9579,
        1.089, 0.1872, -0.6408, -0.9897, 0.8503, 1.1019, -0.9211,
        -0.1741, -0.2915, 4.2472, 1.9417, 1.65, 0.6303, -4.7131,
        1.6555, 2.4026, -2.9293, 2.9953, 0.5328, 3.5873, -0.9621,
        -1.4289, -3.2787, 4.1747, -6.0341, -4.6341, -5.7945, 4.142,
        7.5973, 6.4431, 5.9872, 2.1639, -8.6566, 3.3143, -3.4059,
        -0.8775, -3.048, 11.6543, 0.6442, 2.3218, -0.4765, 1.1516,
        -5.5423, -2.5188, 1.0754, -0.0563, -2.9386, -1.1504, 1.0171
    ])?.reshape((1, 2, 7, 7))?;

    assert_eq!(actual, expected, "conv transpose 2d data");

    Ok(())
}